#!/bin/bash
# Environment and logging setup for the inference sweep.

# Order CUDA devices by PCI bus ID and expose only device 0 to child processes.
export CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0

# Everything printed from here on (stdout and stderr, including the output of
# commands started by this script) is written to the log file instead of the
# terminal. The log file is truncated on every start.
log_dir=./results/log
log_file=${log_dir}/inference_0.out
mkdir -p -- "$log_dir"
exec &>"$log_file"

# ---- Sweep axes: one inference run is launched per combination ----------

# Model directory names; each is appended to the local model root to form
# the --model_path argument.
modelabbrs=(
    "llama-3.1-8b-instruction"
    "gemma-2-9b-it"
    "mistral-7b-Instruct-v0.3"
)
# Dataset names; only train/test pairs with identical names are run.
train_datasets=("gsm8k-plus-mini")
test_datasets=("gsm8k-plus-mini")
# Values forwarded as --emb, --k, --method and --permutation respectively.
embs=("all-roberta-large-v1")
ks=(0)
methods=("knn")
permutations=(1)

# ---- Scalars shared by every run -----------------------------------------

# Forwarded as --subset_size, --metric and --freq.
subset_size=100
metric="cosine_similarity"
freq=32
# Number of repetitions for methods whose name contains neither "knn" nor
# "k_means"; those two always get a single repetition.
exp_num=10
# Argument handed to decoding_args_helper.py to obtain extra CLI flags.
decoding="greedy"
# Accumulator for the total number of planned runs.
total_num=0
#######################################
# Print how many inference runs the launch loop will perform.
#
# The launch loop starts one run for every combination of
#   matching train/test dataset pair x model x embedding x k
#   x per-method repetition x permutation,
# so the total is the product of those factors. Permutations are part of
# the product because the launch loop advances its run counter once per
# permutation; leaving them out makes "Current count: X / total" overshoot
# as soon as more than one permutation is configured.
#
# Globals (read): train_datasets, test_datasets, modelabbrs, embs, ks,
#                 methods, permutations, exp_num
# Outputs:        the total number of runs on stdout
#######################################
count_total_runs() {
    local train_dataset test_dataset method
    local dataset_pairs=0
    local runs_per_config=0

    # Only train/test pairs with identical names are ever run.
    for train_dataset in "${train_datasets[@]}"; do
        for test_dataset in "${test_datasets[@]}"; do
            if [[ "$test_dataset" == "$train_dataset" ]]; then
                dataset_pairs=$((dataset_pairs + 1))
            fi
        done
    done

    # Methods whose name contains "knn" or "k_means" get a single repetition;
    # every other method is repeated exp_num times.
    for method in "${methods[@]}"; do
        if [[ "$method" == *"knn"* || "$method" == *"k_means"* ]]; then
            runs_per_config=$((runs_per_config + 1))
        else
            runs_per_config=$((runs_per_config + exp_num))
        fi
    done

    printf '%d\n' "$((dataset_pairs * ${#modelabbrs[@]} * ${#embs[@]} * ${#ks[@]} * runs_per_config * ${#permutations[@]}))"
}

total_num=$(count_total_runs)

echo "Total number of runs: $total_num"
                        
# Resume point: runs whose 1-based index is lower than target_num are skipped
# (their index is still counted). With -1 every run is executed.
target_num=-1

#######################################
# Launch fast_inference.py once for every configured combination.
#
# Each run is numbered from 1 in launch order; runs numbered below target_num
# are skipped. Before each executed run, decoding_args_helper.py is asked for
# the extra command-line flags that belong to the configured decoding mode.
#
# Globals (read):    train_datasets, test_datasets, modelabbrs, embs, ks,
#                    methods, permutations, exp_num, subset_size, metric,
#                    freq, decoding, target_num, total_num
# Globals (written): current_num, DECODING_FLAGS
# Outputs:           progress lines on stdout; errors/warnings on stderr
# Exits:             1 if decoding_args_helper.py fails, because continuing
#                    would silently run inference without decoding flags.
#                    A failed inference run only logs a warning and the
#                    sweep continues with the next run.
#######################################
run_all_experiments() {
    local train_dataset test_dataset modelabbr modelname emb k method
    local permutation exp_num_method i
    local -a decoding_flags

    current_num=0
    for train_dataset in "${train_datasets[@]}"; do
        for test_dataset in "${test_datasets[@]}"; do
            # Only train/test pairs with identical names are run.
            if [[ "$test_dataset" != "$train_dataset" ]]; then
                continue
            fi
            for modelabbr in "${modelabbrs[@]}"; do
                # change modelname to locate the model correctly
                modelname="/home/amax/exp/huggingface/transformers/${modelabbr}"
                for emb in "${embs[@]}"; do
                    for k in "${ks[@]}"; do
                        for method in "${methods[@]}"; do
                            # Methods containing "knn" or "k_means" get one
                            # repetition; all others get exp_num repetitions.
                            if [[ "$method" == *"knn"* || "$method" == *"k_means"* ]]; then
                                exp_num_method=1
                            else
                                exp_num_method=$exp_num
                            fi
                            for ((i = 0; i < exp_num_method; i++)); do
                                for permutation in "${permutations[@]}"; do
                                    current_num=$((current_num + 1))
                                    if ((current_num < target_num)); then
                                        continue
                                    fi
                                    echo "Current count: $current_num / $total_num"
                                    if ! DECODING_FLAGS=$(python decoding_args_helper.py "$decoding"); then
                                        echo "ERROR: decoding_args_helper.py failed for decoding '$decoding'" >&2
                                        exit 1
                                    fi
                                    echo "Decoding Flags: $DECODING_FLAGS"
                                    # Split the helper output into words without
                                    # pathname expansion. read returns non-zero
                                    # at end of input even though the array is
                                    # filled, hence "|| true".
                                    decoding_flags=()
                                    read -r -d '' -a decoding_flags <<<"$DECODING_FLAGS" || true
                                    python -u fast_inference.py \
                                        --model_path "$modelname" \
                                        --train_dataset "$train_dataset" \
                                        --test_dataset "$test_dataset" \
                                        --max_new_tokens 1024 \
                                        --prompt_template_style "$test_dataset" \
                                        --subset_size "$subset_size" \
                                        --k "$k" \
                                        --exp_num "$i" \
                                        --method "$method" \
                                        --emb "$emb" \
                                        --metric "$metric" \
                                        --permutation "$permutation" \
                                        "${decoding_flags[@]}" --freq "$freq" \
                                        --apply_chat_template ||
                                        echo "WARNING: run $current_num failed with exit status $?" >&2
                                done
                            done
                        done
                    done
                done
            done
        done
    done
}

run_all_experiments